{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 数据集\n",
    "def loadDataSet():\n",
    "    postingList = [['my', 'dog', 'has', 'flea', 'problems', 'help', 'please'],\n",
    "                  ['maybe', 'not', 'take', 'him', 'to', 'dog', 'park', 'stupid'],\n",
    "                  ['my', 'dalmation', 'is', 'so', 'cute', 'I', 'love', 'him'],\n",
    "                  ['stop', 'posting', 'stupid', 'worthless', 'garbage'],\n",
    "                  ['mr', 'licks', 'ate', 'my', 'steak', 'how', 'to', 'stop', 'him'],\n",
    "                  ['quit', 'buying', 'worthless', 'dog', 'food', 'stupid']]\n",
    "    classVec = [0, 1, 0, 1, 0, 1]  # 1代表侮辱性文字，0代表正常言论\n",
    "    return postingList, classVec\n",
    "\n",
    "# 创建不重复数据集词列表\n",
    "def createVocabList(dataSet):\n",
    "    vocabSet = set([])\n",
    "    for document in dataSet:\n",
    "        vocabSet = vocabSet | set(document)  # | 表示并集\n",
    "    return list(vocabSet)\n",
    "\n",
    "# 返回输入数据的词向量\n",
    "def setOfWords2Vec(vocabList, inputSet):\n",
    "    returnVec = [0] * len(vocabList)\n",
    "    for word in inputSet:\n",
    "        if word in vocabList:\n",
    "            returnVec[vocabList.index(word)] = 1\n",
    "        else:\n",
    "            print('the word: %s is not in my Vocabulary!' % word)\n",
    "    return returnVec"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['posting',\n",
       " 'mr',\n",
       " 'worthless',\n",
       " 'dalmation',\n",
       " 'stop',\n",
       " 'maybe',\n",
       " 'garbage',\n",
       " 'how',\n",
       " 'to',\n",
       " 'cute',\n",
       " 'stupid',\n",
       " 'problems',\n",
       " 'is',\n",
       " 'my',\n",
       " 'quit',\n",
       " 'help',\n",
       " 'so',\n",
       " 'take',\n",
       " 'buying',\n",
       " 'love',\n",
       " 'steak',\n",
       " 'please',\n",
       " 'him',\n",
       " 'dog',\n",
       " 'food',\n",
       " 'ate',\n",
       " 'licks',\n",
       " 'I',\n",
       " 'not',\n",
       " 'flea',\n",
       " 'has',\n",
       " 'park']"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "listOPosts, listClasses = loadDataSet()\n",
    "myVocabList = createVocabList(listOPosts)\n",
    "myVocabList"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 1,\n",
       " 1,\n",
       " 0]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "setOfWords2Vec(myVocabList, listOPosts[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[1,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "setOfWords2Vec(myVocabList, listOPosts[3])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 训练算法"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from numpy import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 参数：词向量矩阵，言论类别标签\n",
    "# 返回值：每个非侮辱性词汇概率，每个侮辱性词汇概率，侮辱性言论概率\n",
    "def trainNB0(trainMatrix, trainCategory):\n",
    "    numTrainDocs = len(trainMatrix)  # 言论条数\n",
    "    numWords = len(trainMatrix[0])  # 单条词向量数量\n",
    "    pAbusive = sum(trainCategory) / float(numTrainDocs)  # 侮辱性言论出现的概率\n",
    "    \n",
    "    p0Num = ones(numWords)\n",
    "    p1Num = ones(numWords)  # 每个侮辱性词汇出现的次数（初始化为1防止概率为0）\n",
    "    p0Denom = 2.0\n",
    "    p1Denom = 2.0  # 侮辱性词汇出现的总次数\n",
    "    \n",
    "    for i in range(numTrainDocs):\n",
    "        if trainCategory[i] == 1:\n",
    "            p1Num += trainMatrix[i]\n",
    "            p1Denom += sum(trainMatrix[i])\n",
    "        else:\n",
    "            p0Num += trainMatrix[i]\n",
    "            p0Denom += sum(trainMatrix[i])\n",
    "    p1Vect = log(p1Num / p1Denom)  # 每个侮辱性词汇出现的概率（加log防止下溢出）\n",
    "    p0Vect = log(p0Num / p0Denom)\n",
    "    return p0Vect, p1Vect, pAbusive"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "listOPosts, listClasses = loadDataSet()\n",
    "myVocabList = createVocabList(listOPosts)\n",
    "\n",
    "trainMat = []  # 词向量矩阵\n",
    "for postinDoc in listOPosts:\n",
    "    trainMat.append(setOfWords2Vec(myVocabList, postinDoc))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([-3.25809654, -2.56494936, -3.25809654, -2.56494936, -2.56494936,\n",
       "        -3.25809654, -3.25809654, -2.56494936, -2.56494936, -2.56494936,\n",
       "        -3.25809654, -2.56494936, -2.56494936, -1.87180218, -3.25809654,\n",
       "        -2.56494936, -2.56494936, -3.25809654, -3.25809654, -2.56494936,\n",
       "        -2.56494936, -2.56494936, -2.15948425, -2.56494936, -3.25809654,\n",
       "        -2.56494936, -2.56494936, -2.56494936, -3.25809654, -2.56494936,\n",
       "        -2.56494936, -3.25809654]),\n",
       " array([-2.35137526, -3.04452244, -1.94591015, -3.04452244, -2.35137526,\n",
       "        -2.35137526, -2.35137526, -3.04452244, -2.35137526, -3.04452244,\n",
       "        -1.65822808, -3.04452244, -3.04452244, -3.04452244, -2.35137526,\n",
       "        -3.04452244, -3.04452244, -2.35137526, -2.35137526, -3.04452244,\n",
       "        -3.04452244, -3.04452244, -2.35137526, -1.94591015, -2.35137526,\n",
       "        -3.04452244, -3.04452244, -3.04452244, -2.35137526, -3.04452244,\n",
       "        -3.04452244, -2.35137526]),\n",
       " 0.5)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "p0V, p1V, pAb = trainNB0(trainMat, listClasses)\n",
    "p0V, p1V, pAb"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 测试算法"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def classifyNB(vec2Classify, p0Vec, p1Vec, pClass1):\n",
    "    p1 = sum(vec2Classify * p1Vec) + log(pClass1)\n",
    "    p0 = sum(vec2Classify * p0Vec) + log(1.0 - pClass1)\n",
    "    if p1 > p0:\n",
    "        return 1\n",
    "    else:\n",
    "        return 0\n",
    "    \n",
    "def testingNB():\n",
    "    listOPosts, listClasses = loadDataSet()\n",
    "    myVocabList = createVocabList(listOPosts)\n",
    "    trainMat = []\n",
    "    for postinDoc in listOPosts:\n",
    "        trainMat.append(setOfWords2Vec(myVocabList, postinDoc))\n",
    "    p0V, p1V, pAb = trainNB0(array(trainMat), array(listClasses))\n",
    "    \n",
    "    testEntry = ['love', 'my', 'dalmation']\n",
    "    thisDoc = array(setOfWords2Vec(myVocabList, testEntry))\n",
    "    print(testEntry, 'classified as: ', classifyNB(thisDoc, p0V, p1V, pAb))\n",
    "    \n",
    "    testEntry = ['stupid', 'garbage']\n",
    "    thisDoc = array(setOfWords2Vec(myVocabList, testEntry))\n",
    "    print(testEntry, 'classified as: ', classifyNB(thisDoc, p0V, p1V, pAb))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['love', 'my', 'dalmation'] classified as:  0\n",
      "['stupid', 'garbage'] classified as:  1\n"
     ]
    }
   ],
   "source": [
    "testingNB()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 词袋模型\n",
    "def bagOfWordsVecMN(vocabList, inputSet):\n",
    "    returnVec = [0] * len(vocabList)\n",
    "    for word in inputSet:\n",
    "        if word in vocabList:\n",
    "            returnVec[vocabList.index(word)] += 1\n",
    "    return returnVec"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 实例：过滤垃圾邮件"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [],
   "source": [
    "def textParse(bigString):\n",
    "    import re\n",
    "    listOfTokens = re.split(r'\\W+', bigString)\n",
    "    return [tok.lower() for tok in listOfTokens if len(tok) > 2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'--- Codeine 15mg -- 30 for $203.70 -- VISA Only!!! --\\n\\n-- Codeine (Methylmorphine) is a narcotic (opioid) pain reliever\\n-- We have 15mg & 30mg pills -- 30/15mg for $203.70 - 60/15mg for $385.80 - 90/15mg for $562.50 -- VISA Only!!! ---'"
      ]
     },
     "execution_count": 58,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bigString = open('email/spam/1.txt').read()\n",
    "bigString"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['codeine',\n",
       " '15mg',\n",
       " 'for',\n",
       " '203',\n",
       " 'visa',\n",
       " 'only',\n",
       " 'codeine',\n",
       " 'methylmorphine',\n",
       " 'narcotic',\n",
       " 'opioid',\n",
       " 'pain',\n",
       " 'reliever',\n",
       " 'have',\n",
       " '15mg',\n",
       " '30mg',\n",
       " 'pills',\n",
       " '15mg',\n",
       " 'for',\n",
       " '203',\n",
       " '15mg',\n",
       " 'for',\n",
       " '385',\n",
       " '15mg',\n",
       " 'for',\n",
       " '562',\n",
       " 'visa',\n",
       " 'only']"
      ]
     },
     "execution_count": 59,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "textParse(bigString)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 99,
   "metadata": {},
   "outputs": [],
   "source": [
    "def spamTest():\n",
    "    docList = []  # 存储每封邮件的词列表集合\n",
    "    classList = []  # 邮件类别标签集合\n",
    "    fullText = []  # 所有邮件的词的集合\n",
    "    for i in range(1, 26):\n",
    "        wordList = textParse(open('email/spam/%d.txt' % i).read())\n",
    "        docList.append(wordList)\n",
    "        fullText.extend(wordList)\n",
    "        classList.append(1)\n",
    "        \n",
    "        wordList = textParse(open('email/ham/%d.txt' % i).read())\n",
    "        docList.append(wordList)\n",
    "        fullText.extend(wordList)\n",
    "        classList.append(0)\n",
    "        \n",
    "    vocabList = createVocabList(docList)\n",
    "    trainingSet = list(range(50))  # 存储标号的训练集\n",
    "    testSet = []  # 存储标号的测试集\n",
    "    \n",
    "    for i in range(10):  # 随机将10个样例加入测试集\n",
    "        randIndex = int(random.uniform(0, len(trainingSet)))  # 返回区间随机数\n",
    "        testSet.append(trainingSet[randIndex])\n",
    "        del trainingSet[randIndex]\n",
    "    \n",
    "    trainMat = []\n",
    "    trainClasses = []\n",
    "    for docIndex in trainingSet:  # 构建训练集词向量矩阵\n",
    "        trainMat.append(setOfWords2Vec(vocabList, docList[docIndex]))\n",
    "        trainClasses.append(classList[docIndex])\n",
    "        \n",
    "    p0V, p1V, pSpam = trainNB0(array(trainMat), array(trainClasses))\n",
    "    \n",
    "    errorCount = 0\n",
    "    for docIndex in testSet:\n",
    "        wordVector = setOfWords2Vec(vocabList, docList[docIndex])\n",
    "        if classifyNB(array(wordVector), p0V, p1V, pSpam) != classList[docIndex]:\n",
    "            errorCount += 1\n",
    "    print('the error rate is: ', float(errorCount) / len(testSet))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 100,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "the error rate is:  0.1\n"
     ]
    }
   ],
   "source": [
    "spamTest()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
